{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scalable GP Classification in 1D (w/ SVGP)\n",
    "\n",
    "This example shows how to use grid interpolation based variational classification with an `AbstractVariationalGP` using a `VariationalStrategy` module while learning the inducing point locations. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gpleiss/anaconda3/envs/gpytorch/lib/python3.7/site-packages/matplotlib/__init__.py:999: UserWarning: Duplicate key in file \"/home/gpleiss/.dotfiles/matplotlib/matplotlibrc\", line #57\n",
      "  (fname, cnt))\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "from math import exp\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = torch.linspace(0, 1, 260)\n",
    "train_y = torch.cos(train_x * (2 * math.pi)) + 0.1 * torch.randn(260)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models import AbstractVariationalGP\n",
    "from gpytorch.variational import CholeskyVariationalDistribution\n",
    "from gpytorch.variational import VariationalStrategy\n",
    "class SVGPRegressionModel(AbstractVariationalGP):\n",
    "    def __init__(self, inducing_points):\n",
    "        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(-1))\n",
    "        variational_strategy = VariationalStrategy(self,\n",
    "                                                   inducing_points,\n",
    "                                                   variational_distribution,\n",
    "                                                   learn_inducing_locations=True)\n",
    "        super(SVGPRegressionModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "        \n",
    "    def forward(self,x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "        return latent_pred\n",
    "\n",
    "\n",
    "# We'll initialize the inducing points to evenly span the space of train_x\n",
    "inducing_points = torch.linspace(0, 1, 25)\n",
    "model = SVGPRegressionModel(inducing_points)\n",
    "likelihood = gpytorch.likelihoods.GaussianLikelihood()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 1 - Loss: 1.495 [-1.487, 0.008, 0.000]\n",
      "Iter 51 - Loss: 0.926 [-0.873, 0.054, 0.000]\n",
      "Iter 101 - Loss: 0.573 [-0.447, 0.127, 0.000]\n",
      "Iter 151 - Loss: 0.380 [-0.116, 0.264, 0.000]\n",
      "CPU times: user 18.9 s, sys: 19.9 s, total: 38.9 s\n",
      "Wall time: 5.57 s\n"
     ]
    }
   ],
   "source": [
    "from gpytorch.mlls.variational_elbo import VariationalELBO\n",
    "\n",
    "# Find optimal model hyperparameters\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Use the adam optimizer\n",
    "optimizer = torch.optim.Adam([\n",
    "    {'params': model.parameters()},\n",
    "    {'params': likelihood.parameters()}\n",
    "], lr=0.01)\n",
    "\n",
    "# \"Loss\" for GPs - the marginal log likelihood\n",
    "# n_data refers to the amount of training data\n",
    "mll = VariationalELBO(likelihood, model, train_y.size(0), combine_terms=False)\n",
    "\n",
    "def train():\n",
    "    num_iter = 200\n",
    "    for i in range(num_iter):\n",
    "        optimizer.zero_grad()\n",
    "        output = model(train_x)\n",
    "        # Calc loss and backprop gradients\n",
    "        log_lik, kl_div, log_prior = mll(output, train_y)\n",
    "        loss = -(log_lik - kl_div + log_prior)\n",
    "        loss.backward()\n",
    "        if i % 50 == 0:\n",
    "            print('Iter %d - Loss: %.3f [%.3f, %.3f, %.3f]' % (i + 1, loss.item(), log_lik.item(), kl_div.item(), log_prior.item()))\n",
    "        optimizer.step()\n",
    "        \n",
    "# Get clock time\n",
    "%time train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAADDCAYAAABtec/IAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJztnXl4VOXdsO8zM5nJPtlDSCAk7HsCoyJuQNAKVqsUi0s/bWu1b1t9bUuxULVqXatFa92+KvpqP4svilgXUJFgq1ZUJiQiCGhI2Mk6meyZ9Xx/nJlhkkyS2ZJMmOe+rrmSOesz55znd57nt0qyLCMQCKIT1XA3QCAQDB9CAAgEUYwQAAJBFCMEgEAQxQgBIBBEMZpQD2AwGBa7/r3QaDT+LtTjCQSCoSOkEYCr819pNBq3AXMMBsOc8DRLIBAMBVK4/AAMBsNBo9E4PiwHEwgEQ0JYdAAGg+E24GfhOJZAIBg6wjkCeA240Wg0mn2tX716tXA5FAiGiYceekjytTwkJaB7zm80GncBVcBNwMN9bX/PPfcMeMy6ujqysrJCadagE+ltjPT2QeS3MdLbB/638a677upzXahTgMVAmuv/FBQhIBAIRgihCoBngUKDwXATgNFo3Bh6kwQCwVAR0hTANd9/NkxtEUQJdrud1tZWWltbidRoVKfTSUtLy3A3o196tlGSJHQ6HaNGjUKj8a9rh+wIJBAESk1NDXq9nvT0dCTJp25q2LHZbMTExAx3M/qlZxtlWcZsNlNTU0NeXp5fxxCuwIIhx2KxkJycHLGdf6QiSRIpKSlYLBa/9xECQDDkyLIcEZ2/vLyc8vLyQT+P2Wxm06ZNg34eUIRAINMqIQAEEc3JkydZvHgxNTU1QR+jvLycdevWUVpayrp166iqUoxVer2ejRsHX2+dkpLi8zzl5eVMnTqVTZs2sWnTJtauXetpmy/6WxcsQgcgiGgefPBBPv30Ux544AH++te/Bry/2WzmkUceYf369Z5l11xzDevXryctLa2fPcNLampqr2XFxcUUFBSwbNkyz7KlS5eyZcuWXttWVVXx/PPPc//994e1XUIACCKSlJQUurq6PN+fffZZnn32WWJjYzGbfTqb+mTjxo0sWrSo27LU1FRKS0uZO3cu5eXllJaWUlFRwQ033EBZWRkAZWVlLF++nO3bt5OWlkZBQQHV1dVs3LiRgoICJk+ezHvvvcf69ev55S9/ycqVKwG6bV9QUMDzzz9PUVERu3bt8vt3u9/027dvB2DRokVUVFRQXV1NeXk5er2e7du343A4uPDCCyksLPT7evRETAEEEcm+fftYsWIFcXFxAMTFxXHVVVexf//+gI/V3Nzc57ri4mJKSkooKiri+eefp6Kigu3bt7Nw4ULuuOMO5s6d6+n8ixYtIjU1lfvvv5/rrrvOc4xly5ZRWFjYa/vbb7+dK664gpKSEgoKCgJqc2FhIWlpaaSlpfHGG2+waNEiCgoKKC4u7rUuFIQAEEQkOTk5JCcnY7FYiI2N9VgORo0aFdBxFi1a5Hmru6murqakpKTbMvd04IorruCGG25g7dq1WK1W9Ho9xcXFnlFESkpKt2OvXbuWuXPnepb13D5QzGYzhYWFrF27Fr1eT1FRkWc5KFMB97rZs2d3WxcMYgogiFjq6uq48cYbueGGG3j++eeDUgQWFhayatUq1q1bR0FBARUVFTz55JOe9WazudsUwD1kX7hwIRdeeCHPP/+85+3rHoKbzWZSUlJYvnw5t99+u0co3Hfffd22X7lyJW+88QZFRUWefYuLiz3nLi8vp7q62mMhqK6u9rTNfb7m5maqqqpoamrCbDZTXV3tWWcymaiqqqK6urrbcQMhbNGAA7F69WpZBAMNDZHevsrKSvLz8yPa0WYkOgK5qaysZMKECZ7vd911V5/RgGIKIBBEMUIACARRjBAAAkEUIwSAQBDFCAEgEEQxQgAIBFGMEAACQRQjBIDgtKa8vJyzzz67W9hvVVVVr2XRSjhKg93k+ne8KA0mCJTYWF1YjtPV5TsJRnFxsccT8KmnngKU2AC3X320E47SYNuMRqM7OejigfYRCIYavV7f57qqqirWrVvHpk2bKC8v93x/4YUXqKqqorS0lKVLl1JaWsrtt98+hK0eGkKdAhSipAYHJSV48HGJgqikq8sSls9ALFu2jHXr1vXyx+8Zwdcz0q6kpISUlBRKSkpCCrqJVELNCuydEXgOsCG05ggEg0NJSQnXXHNNt8g9N3q9nsLCQgoKCli7di1FRUWMGTOGI0eOYDabfSbzOF0ISzSgq0LQLleFoD6pq6sb8FgjQcpGehsjvX1OpxOHwzEk56qoqOC5555jzJgxFBUVkZeXx86dOykvL2fnzp3cfffdPPvssyxatIhx48YxduxYKisraWhooLKykrfffpuqqiq++eYbqqqq2LlzpydEd7jp6xo6nU6/+hqEKRrQYDDcZjQa+ywJBiIacCiJ9PaJaMDwEBHRgAaD4SZ35xdKQIFgZBEOK8CfDAbDQYPB0BSmNgkEgiEiVCXgNuD01ZAIBKc5whNQIIhihAAQCKKYiBMAxqOtHG7swOGMzKqxAgEMbbmvwSTisgIfqO/kYMtJYtQqxqTGUZiRQEFGPLEx6uFummCQeOLDgyHtf8vC8f2uLy8vp6yszJObf/v27X5X2HE7BlVUVHiKf8Cpcl/eVX2GAlmWsTuVTzg6b8QJADc2h5OqhnaqGtqRJIm8lFgKMxMYn5FAgi5imy2IMHyVBvP3zW02mzGZTJSUlPgsIzYUHoJOp4zV4cTmcGJzKB3fXVw1NTb0AfyI6EmyLHO0qZOjTZ189G0jY9PimJKdRGFGPBp1xM1iBBHExo0be7n/rly5kqqqKrZv305BQQHNzc3o9XrWrl3LypUr2b59O3fffTdlZWVUV1dTWlrKHXfcwY4dOzCbzb3KfbmP5S4JZjKZuh3r/vvv99T2c7elqKio2z7u+ARZlrE5ZKx2JxaHE4dDRmbwpsMjrvfIsszhxg7e/7qWFz49wocH6mlssw53swQjCHcZr5/+9KeUlJSwceNGn0E/7iChkpIS5syZA+Cz3FfPgCJfx1q7di033HADy5YtY9GiRb32sdqdtHbZaGy30tRhpd1qx+5wDmrnhxEyAugLi93BnhMt7DnRQl5qHLNy9RSkx6NSDX/teUFksHz5cn7xi190W1ZaWgrgqfBjNptDDvrxDigC39MD9zTCXUkoOVlPzpix6LNzaeoYnpfYiBYA3hxr6uRYUyfJsTHMzU9h6qgk1EIQRD0pKSndSoM1NzdTVFTEfffdx8aNG0lLS2PZsmVUV1dTXV3tKbVVUVFBS0uLpxTYrl27KC8v91nuq2dJsJ7Hcu/3yCOPMHfuXMaOzWf1nXfzxP/9G+edv5Cx+eNISEwelusTcaXBHntvDxpdXMjnS9JpKB6bwvScpLDrCSI92CbS2xetwUA2h5MOqwOL3Umo/c6tBAw1GOi0GQH0pNVi56NvG9h1xMyZ41KZOipJTA0Ew4Ld4aTd6qDLNjQh0IEw4pSAgdJmsbP9QD2v7DxGVUP7cDdHEEU4nDKtXTZMHbawdH6nw079sUM4HfYwtE7htBcAbkwdVjZ/VcOm8hPCajDMSJIU8hA4kpFlmXaLHVO7lQ6rI2y/taWxHktnOy2N9f2eW5L8H+metlOAvjhu7uQV4zFm5SYzryANrSZqZGDEoNPpaGlpIT09PaCHdSRgtTtptSgmvHBxvPLrbkKkrdlEW7OJ4ypVt/yGsixjNpvR6fzPtBx1AgCUC/XlsWa+rWvnnPFpTBmVNNxNiipGjRrF4cOHaW5ujtiRgNPpRKXy/+UgAxa7E5s9fB3f3Y729g5UKhU2q8Xzho/RxZKemkJlZaVnW0mS0Ol0jBo1yu/jR6UAcNNhtfPBvjq+qWtj4aRMkmKj+nIMGRqNhqSkpIi2VARiSTnc2EHp/nrarXYgvDErm578I59teZWsMYXUHa1CHaPFYbMyb+kKNjzzcMjXMKLHvy2mep5edT0tpr7nPOHgcGMH6784yu7jkftGEkQeVruTDw/U89buk67OHz7WXFbMqiXT2bF5A7IsU3vkILIs43TYidenUvHvLX4n/uyPiBYA29Y/Q/UeI3+55cpuQmAwBIPV4eTf3zTwRsVJWjptYTuu4PTkZHMX/2s8xp4TLWE/doupnqyx44lPSkGjPTWf12h1zJi/mHazic62FpYuXUpNTU1I5wqLAHClBQ8b3tIPoNVUz73XLmDNZcW0mOr5yy1XcmhvGVteeNQjCHwJhWAExXFzJ6/sPMb+mtZw/iTBaYIsyxgPN/F6+QmaB+lFseWFRzlxcB8drWbs1lNFT+xWC7s/ft/zva6ujnHjxnlci4MhHFmBFwOvhXocgBUrbNx//Qky8x4HzgO6K+fsNiv3XruAVlM9sixTVvoW1XuM3HvtAra88CiH9paxbf0znu23rX+m1zJ/sDqcfLCvjnf31kak84ZgeOiwOnhzdw07qkyDMlV0v/jKSt/ycw8tAF1dXUELgZC1XkajcZvBYKgK9Tg2G7z5pga4lLZmgP9yrdkHlALbgH8BzT73d1+0HZs3eEYObtzL1JoYHnq7wu82Vda1UdPcxUXTsshNCd09WTByOdrUydav6+gIYa7fYqrn5Qd/yw/X/JnktMxe6/sWKinA2cAMYKLXR4NancuVV17JQw89FFSbIkIHkJKSQlJSPHAB8HPgWcAIWICpwM3AP4FG4D/Ar4G8Po8Xl5jMZMN5qFTdNbJFFyzx/O/v9KDNYueNipPsPNQkFIRRivGwmTe/PBlS54e+dVqgPI+jx09FGxePMvL9AfA0sBtoArYADwM3AguAXCAVhyOO5OTkgEx/3gyp3asvreXHH3/MWWedhdP5BfCF15oYwIBSf3QxihSc7/o8iiIM/hd4GThVDquzrYUDxo97naes9C3KSt9CE6OlaOGlHNpbxrsvPcF3bxq4qnnpnjb2Hanj/MJkrB2RrR+I9NJgEPltNJvNWB1OPq5q4XBTV0jHuu/qc7HbTnmfunVamhgtd7zyCServ+Fvq34NfA94EOVZ13odwYJKs4tJc+JpMX3Kyaq3mXxGDnf/eimbN1/BkSNHgrYIDKkA6MtmmZWVxdVXX80//vGPHmtswA7X514gEbgYWAFcApzj+jyMIgieRhk5+Eaj1eGw27DbrBi3vg6AcevrGLe+jiZGy4Nvlffb/mYHlB62ckZ2fETbsKHvax1JRHIbmzrt/OuIjSarmoSEhJCOdfNjr/CXm7/fa7ndZufu5XcCNwEnONUdHcC/Ucdsx2HbRtaYJuqP7Sc18wekZsLJg++RmrGC+fPP5PLLvxtS2yLG86WtrY0JEyb08GxSodHpsHV1urcCNro+CcClwI+Bi1x/fwzsRBEImwAnSBLIMpJKhd1qIS45BZxOutpbkWWZGF0sM+aX8N2frvKrnR1WO+8eaMGhS6R4TPDaV0HkUtXQzjtfm9DGhq73WXNZcbe3v0Iqio7rJmCca5kN2Izy3L4N1ONwGRnqjip/vXVbOzb/L2PH/i+xsbEhjabCYQVYrvwxLA/lOBs2bGD69OlkjSkESUJSqZBlZz++4u1I0gZ+9WQbowouJUb3JIqO4AwUo8Re4HqSUnPIzp9A4cwzyB47gc4WM51tLZ75vM3SRWx8ok+lTF/IsswnlY28u7cWa5hdPwXDi/FwE1u+qsUWoC9/i6mev/7qap749dWe+X3vzp8FPAQcBh5A6fxVwBpgLPBd4AUKZowlY/RY5eXlhSSpkFzuyTG6WC6//HL2798fxK88RTisAO5Xcshs2LCBuQuWMH7WmcxbeiWfbXmNPTtKiUtIRpLAXN/d6WHOokv5/N1Xqal+B3gHbexdWLuuBH4HTAFepNV0iFbTXdQefhnwfVM/2/Iqy27+Q8Dtraxrw9RmZenMbFLjtQPvIIhYbA4npfvr+bauLaj9t61/hqMHdgOKHb/u2CEycsfScPwIdlsKcDuKAs89qtgKrAU+gB55/7Sx8UwomkfjyaOgUiE7ncoL0ekEWZnK2q0WEhMTg1b+uYmYKYCbH65+xJMRaNnNd7Ls5jsBeP2Je/hsy6unLoQk9bKXWrtMwN+A54GrUSTrVOAldHF3Yum8BXiv1zllWWbNZcUD6gB8Yeqw8lrZCS6alsW49PiA9xcMP21ddt75qob6NsvAG/fA1xD/1HMZB/wWWM0pn5Z/orz9d/Y6lkqtwemwk5Y9mtamRuYtXUHdsWrazI0kpqRjqjmGBFz/h7/y+buv0dDQEHB7exJxAqAv2swmzr7kKs/IwFR7nJgYLXs//1ARCN2wA/8P+AdwDZL0IJbOCcC7KP4Ev0KZIoBKpWb2BRdzwfd/wtOrru/TRtsfFruDd3bXMK8wFUO+qJU6kqhtsbD5q5qgffnXvLiVp35zLaba415LJeBalI4+xrXsbeD3wJ4+j+VO9OGe6/enmF528x/4wbTEoNrszYgRANff+bjnf+9RQe/O742TuMS36Gx7DbVmJQ77b1FMLBXAY8A9OJ3tVH75BXabneo9Rt585gFazSaPIBjIecONjMyOKhP1bVYWT8kkRtQriHgq69r4YF899n6fob7xreCbguLHcp7r+y6UUcCHrtGrslSSFB1Xek4eWWMncLDic2w2C7LTGbBiOhRG9FPaZjaRmp3LrPO+w6zzvkOMLpYYXSyzzvsOkqT8tM62FsCCw/4AMB54CuVnryIxpR64klZTPV99ovhY7/5kK9V7jNz3w0VA4O7ElXVtbNx1gtau8EaHCcLLzkNNvLe3LujO32KqJ6dwCqnZua4lOuCPwJconb8W+BGKH8uHxCUmUzjzDFKzc0nLzuVXT77G2ZdcRU7BFH5y91PMKbkUZNkzvw9UMR0sI2YE4AvvUUFPWkz1vPPcI3z16TavgIomFK/C/wGeps18JvAq8BaKWeakZ39ZdrJqyXTPd7c7sdt5oz8a2ixsMB5j6YxRjE6JDe7HCQYFu8PJ9gMNHKgNzpnLPSI89PUur9HnfOAFYDIA2rh/EKO9i66Oo8QlpqPV6hg9fmqv59U9kgXlZTZv6QrPFLe1KfT5vT+MaAHQH8lpmejiE3DYrB5fADcqVQVTz7qD6j1z6GhdDVyGIrVvRdEdKKg1MUiShN1mDXhY1mlz8EbFCRZOzmBazvDkfBd0p9PqYPOeGk42B+/Zd9//WeTV8bXA3cBtgBqV+hs0mlsYlV9DZu50dm2vYub8Er8sTL6muEPBaSsA4JRUrTtWzYmD++hsa0FSqXA6Hezd8QGKCeZlFMvBJcDfUbwMbwRO4rArnhg9h2Xt7f5lF3bKMqX762los3Lu+HSRlnwYMbVbeXt3DS1dgYXwut/4R/Z/6XkeFGagPDuzUTz3HsDpuAerw8qR/XBkv2IS9B45BmNlGmxOawHgLVVfuvdWklIzmLf0Sj7a9Hcqv/yM5sY6kI+jOGBcBzyOIgh2Az8F3gRwTSEkvvz4fRZf83PUusDMfV8ea6apw8bF07LQiTLnQ85RUwfv7q3DYg88tHvLC49SvceIOsbbz+MW4BGUef9BlGfn0z6PUbzwkiFR6AXDiFYCBsL1dz7OspvvZHThFK767QNMPfMCJPB4Vkmql4FpIG0FMlDstf8XiGf8rLNITEmls7U54NwCbo6YOnht14lhqwEXrXx1vIU3d9cE3Pnvu/rcbrH5DpsVJSx3E/BXlM7/N5QRgO/OL0kqJEkaMoVeMESNAOiJe3pQOPMML1fhBJAvRvETsAA/A8o4uLuLNrOSBGLH5g3cvfxMVi2dwfGD+3u5f/ZHk8tp6IipY5B/ncDplPn3tw3865v6AcO4fYWG997nLKAcuAIl8nQZiuK4HUmlRuujnN3Mcxczb+kKWpsaQ/sxg8hpPQXoD+/pwZrLiqk9XOm19nHgQ2A9MB34DPhv4LlTm8gyLz+4kobjhwDFXOiPssdid/DW7hrOGZ8mgokGCavdyXtf13K40T9B623qdd/D6edcyO5/b3FtcSvKkD8GJVx9BXDIs7/sdJCQkkYCMGbSDACOfbMHp8M5pAq9YIg4ATB/XDLahGQ6rQ46rA6au+zUt1oCDs4IhDUvbuWd5x5hz45SbJYuNFodCcm1NDcYgL+gjASeRUlY8jNAUQK6Oz8EpuxxBxM1tllZMCkj7MVLoxlzh43NX9Vg8mOq1dORp3c2qXiU+36t6/tjaLR3YbcqJkRJpWbS3Pkk6tOwdLT3a5YOFQmJOK2aeK2aBJ2aRJ2GnjEEwRBxAqAgLZasrO5vRqdTxtRhpbbFwonmLg41doQ1V5/bZGi3WpScATYr1q4OoAtlmPcRynzvWmAusBy3K7EbSaVi+rxFXPHLO/w+776aVpo6bCyZke26oYJQONrUybt7av2e7695cSvPrLqOhhNHui1Pzc6lxZSMw/YKyhy/jdwJf6azdR1tLcqxJZUKZCdpWaODCiQbCJ1GTY4+llHJOrKTdWQn6XopkMORFnxEPHUqlURGoo6MRB3TRyfjdMqcaO6iurGDg/XttAZo2vGFWyfwxfuvI8synW3ejiLrgTKUoMcZwOfAT1CciBRkp5Ok1PSAlT01LV1sMB5nyfRs4TQUAruPNfNRZaPfadt8u/EqNNXOAF5Bids/AFzB8cp93bZx+wIEG0nqiySdhoKMBAoz4slNiRsSs/GIHHuqVBJ5qXGcNyGd6+eN4XuzcxifmYgqhDpzbivB71/6gOIFlxCj69kZDwBnojgKJQAbgD+j0SYw9SwlvZOp9kRQ5+6w2nmj4gRfHvOd8FTQN3aHk2376/j3tw0B5Wxc8+JWihdc0ivmHn6DkpgjFaS3gDORVAeYMb+EGfNLPM9FjC6W4oWXcMfLH4bUfo1KxdRRSSyfk8uP5udzwaQMxqTFD5nPyIgUAN5IksTYtHiWzsjmx/PzObswjbgQbO3e0wG37dcdVwCdKDbfW1AyuKxEG/sf4pMm4bDbSMseHXTREqcs89G3DWz9um5Q9R2nE61ddl4vP8G+k/679brvT2tTI5Vffo4EIKlQzHr/gxKjrwbuAvlyoAVkmaTUDBJT0j3TxFD99VPjtVwwMYOfzB/L4qlZ5OiHZ/Q3IqYA/hKvVWPIT2V2np59J1vZdbQ5qOmBt1/2+j/dRu2Rg548BMrc7ymmzUtl747/oqNlNmXb7gDKuimR/LUK9ORAbSv1bRaWTM8mLUEkGemLo02dvL+3ls4AdUFujf8rD99Ga1ODEkAWO56O5udQks62owj5TV57SR5TXqj++hmJOs7IT2F8ZkJEVEY+rQSAmxi1ill5emaMTmZ/bSufVzfRZvE/Os9bm5uZV0DhzDO6JWZIHZVHV+sOfvXkEp77vZn2lmnAxyguxEpi01BcQE3tVl4tO86CSRmicnEPlMo8Zj4PME17zzl/7ZGDANgsk7FZ3kaJ2z+Mkpn3S892kiRxx8vbe73pAzXvZSXpOGNcKoUZoSUYDTenpQBwo1JJTMtJZmJWIl8ea6bssJLqORB8mXba29s9mWKnz7+PL95bgJLg8WWgGPgdMboYJs89x1O2LNChos1VneiYuZMLJmaI/AIolXm27qvjaBCOVG5Tb/fo0MtQBHYiSor5K4DuU7c5iy4NyYsvSadhXmEak7MTI+KN35OQBYArGagZmGM0Gh8OvUnhJ0atwpCfyrScZL441MSe4y3IIdpQ3UEi2tg45i7eQmXFCZobbgdWArOwWa6i/tgh6o5WseWFRzHVnggq29C+k63UNFu4cGoW2cm6gXc4TTna1MkHX9cFnbknOS2Tio/e9YrkW4WSoFMF/B2kn4GsRAm6p3vZ+eOxdAbntalVq5ibn8LsPH1EC++QWuYuCmo0GrcB5nAXCQ038Vo1CyZlcOXc0WQmhtaZ3HPJtOzR7Nr+Ns0N9wALURJBXAh8Tu0Rbbcahu4kI4HS1GFl467j7DzUhNMZXdWJHE6ZTw828mZFYCW4vZWx7v+Vzh+DkjPyYZTHfw1wPbPOPf9Uso4nXsPwne+TmVsQlHPP+MwErj1rDIb81Iju/BD6CGAFSkwtKPmNF6PkQIpospNj+cHcXL460cJnVaaApgW+vMdO8R+UtOT/BOaguBBfC7wDnEoyEoxewCnLfFZt4lBjBxdOzSIlPiag/Uci9a0WPthXT2N74Mk63QJ6ywuP8k35DtqaGph9/rXs/eyX2K1nAx3AD9Fn7mDC7O9h6Wjn9y9u9ez/3Rt/F3BBkESdhvMnZjA+M7Lm+f0RqgBIAUxe39P729gfz6WhLBmVo4WLCmP5uLqFE83+PWS3Pv1P3n3hUb4p+7hb6eZTHAXORTEprUAJKb4deAiNVsfUsxZy0XX/7XdOgZ4cbG/nUK2JotGJzMyJ9+n7EOllt6D/Njplma9OdlB+vA1ngPUY7736HBy2U5afUxl6J/LlR38AJgEnUKmvwOn4AkkazaX/dTtAt3vS1RVY0pApWfEYxujQyu3U1QV3bwMlHPc5IkqDBbtduCjIy+Gr4y3852DjgDb4hIQE4pKScdisHntwbzqBq1C0yQ+g1Hubhd36U7RaLZse/0M3fYC/iUe92d/kpN5qY9HkDEb5sCFHctktN77aWNPcxYffNNDQ5iQuPvA060XnL/FRXvsCFLNeGlDOzHOfZPHVv/OY8vp60/szAkjQaiiZkkn+MKWED/U+hyoAzChXFZTRQOTGPQ7AzNxkxqbFUbq/nuPmzn63bffyE3jxj/9Ne0sT1q7ObmnHFB5ESQP9D+BqVOqpfFN+A21N5d38BHxFo/lDY7uFjbtOMG10EmeNSyVhBMcTdNkcfFpl4usTrUEpaPt27b0BeAZl7v8WcA1ffdLOvs9fDjlDz4SsRBZOyiB2BCd5CfWJ2YCS9hSgECXp/ohFHxfDFUU5A9qZr7rtYc/b4fcvbvVkG5pxzmLeePKPPYJL3gbmAW/hdBTRatoMLPMReRac74CMzN4TLXxb2+bROo8knE6ZvSdb+bzaFLBTjzdrXtzKG0/dy94d25X7JmlAfhAlJTcoHn63AU7SsnP55aM9C9H6T4xaddr4aIQkAIxG4y6DwmLnkWqyAAAVjklEQVTAbDQaI14BOBCSJHHGuFTyUuN4/+s6vzwJvTXFE4rm9Youg69R4gheBUqAfwG/QJ+xGWtXR9CFSr2xOpzsqDKx53gLE/SQniGjjuAchLIsc6C2jc+rTTR3hh7MlZyWSf2xQy6hnQjyepTisTbg5yiafwVT7XFPee5ARwHpCTounp512nhphqM24LPhaEikkaOP5WpDLqUHGjhY73+9OHcFo96jARNKafM/oySYWEdzw9Mo2YdOVS8ONX1Uq8XOJ9XtVDYfoXhsCtNykiLKFCXLMlUNHZTuNWGRgqvD11Nn0n34n48y1J+Fcs2/jyJw6ebOXXTBkoAF7bScZM6fmB5R1zNURu6kcQjQxahZOiObiqOx/Odgo18aafdowPec1I7S4ctR8g3+AiW8+Epkp2Ih2bF5A4uv+XnIOeRaLXY++raBnYeamJmbzLScZJJih+92W+1O9p5sYfexFlq6bLR32EgI8C3q7vhp2aM9OpPF1/wcXUIidrMJOB8lZDsT2I8yAlAyPWXk5tN44ogn30MgglajUrFw8ukx5O/J6SPKBpGiMXqWFY8OKGnHmhe3os/IRlL5UhC9hPKwHnP9LQPmkZadiyRJbFv/TK+owmCjDDttDr441MRLO47w5pcn+bauDccQORPJsszRpk4+PFDP/3x6mE8qGwNOy+3Nff9nEdV7jJSVvuXJz3jvtQtoN5tQMjVtQ+n876LoXdydfxyWzg7mLV3BLY+tDyhPX3JsDMvnjD4tOz+IEYDf5OhjucqQx/v76tjvhw0/OS2TqWdewOfvvtbHFjtR9KcbUfwG/o2p9jfAU90UhFteeJSrfvtg0JYCNzIyR0wdHDF1oNOoyUuNY1x6HPlp8WG1HtgcTmqau6hq6KCyvp2OIF13vekveYei3X8cZZ4PSu6+1bhLwWfnj+/l0edvIM/YtHi+My1rRGv5B0IIgACI06r53qxRaO0dHGwZeHt3WPGMcxbz0h9vUUyF3agFFqKNexpr543Akyhvrp+heKopjizedu1wFJqw2B0crG/jYH0bEhJpCVqykrSurEta0hO0xGkHfui7bA6aO+00d9qobbVw0txFXZsloCg9f+iZs/EUo1EE6NkoWZxvxF3ZSa3RcOZ3ltPa1BCUO++snASWzBoVkQE84UQIgACRJAnDmCQmSwl8sK//5B3eD96cRZfy+buvoY7RYrdayM4fzzW3PcxfbrkSa+dNQCmwDvghSh66FcC+XseM0cWGFGXYExmZxnaLy932VGINlSQRF6MmTqsmNkaFLCt++U5ZxuGENos9qEIbgeCt7HMnaXGXeYtLvJTOtueAbOAISp7GnUiSiqIFioIvmGujUakomZJJitRx2nd+EDqAoBmfmcAP5uaSEuefT757NHDLY+s5+5KryMwtIDElnbyJM1wZhzagmAr3AzNRpgjXdzuG2/Ow/tghjh74KugiJf7glGXarXYa2iwca+rkuLmTmpYu6loVYTHYnR+6+/Pv/ngrc0ouY+Y5FwG/prNtE0rnL0VJ1LoTUOItgrWkJOo0LCsezaTsxDD+ishGCvdwrS9Wr14t33PPPQNuV1dXF/FurN5ttNgcbN1Xz6HGwP2/Nz35Rz7b8irpo8fScPywa2kC8DRKVhpQ6hX+Ancqcl/0nA545yuIVPprY99z/jSUGIvLXN//hBJn0VsYBTpFyk6KZenMU9mZR9pz2B933XUXDz30kM/hjBgBhIguRs13Z2ZjyE/1e581lxWzasl0dmzegCzLXp0flI5+vevjTk+1CziT2PhE4hKT0WiVUGZ3Yso1XlFspwNrXtxKxuixPZaeA1SgdP4mlOQdqwEH+ozskK7JxKxElhXnRGVqdiEAwoAkSZxdmMbF07P9chJxZ6T1zjCrz8gmJTOH5PQs0nPGEKN9FcVK8BVKBNt/6OpYSWdbh1+JKYM1Gw41Pdu55rJi7r12gZcDlRr4PYozzxhgB1CEEnKtYO3q6Bac5e8UQELizHGpXDw9O2qLs0Tnrx4kJmYl8v3i0STF9q8X6FmIxG61MO2sBdz+923c+fKHrH7hPSYbziU1uxWlJt1aFH3t3Sg5ByZhtyra9r5SkXubDSOZbeufoXqPkUd/uYwnfn01E4rne62diJJr8X6U3/8Qit9Ed1frzjbFlVp2Ovy28atVEhdNy+KsgrQBtz2dib4xzyCTmaRjxdxc3t1b229UoXfmYV8ZZt0WhBZTPU/95lpMte+gOBCdCVSgjVuLtfMeEvXdpx59lbsa7vr0Pd1377v63G7tbDebXA49ABJK6vWHgDgUh6mfcCr3jEKMLhabpatbDIU/b/54rYZLZmT7DKOONoQAGATitGoun53Dx5WN7D7uu9iHP44p3TvzcRTrwOPAj7B23gFcRlnpTZSVnsoydPNjr/Ds73+KtevUVEEXl8BNDzzn8xxDhfeIZPE1Pyd73CROVO714TMwBcVN+gLX95dQx9yGw6a4SmtitOgzc7BbLbQ01gY87E9P0HHprFHD6hYdSYgpwCChUklcMCmDhZMzg65Y5NYVSCr3bWoBfoySc/AgSsDLp8ATOJ1KPcXP332VjpYmT2ETu9VCe7OJz7a86usUg6YrcB+3p8LT7b57/Ns96OK8k2jEoyRP2Y3S+WuBy4EfeTo/gMNuY1LxPMZMmhGwa29+ejzL54wWnd8LIQAGmRmjk1lWPJp4beAPnVtXgCz3qFK0DWU08CCKy+vNOB37WbXkOXZsPlXQwtFjKrBqyXRWX1rU7Rx96Qr8EQz9beM+7uzzL2bG/BKfTjVdHW7T5uUoIdNrUJR+fwOmoaRTU6L40kfnc+MD6zyd3V3KbXThFJbdfOeA3n6z8vR8d8YotBrxyHsjROEQkKOPZcXcXLbsqaW2NbBcc76qFCl0omjH1wOPoeRjfQz4JYp5bBO+ykdPn7eQp1ddz9EDu32XxpYk7nz5Q79iD3oO619+8Le9jtstPZfLi0+SVMiyE+VN/wDgVvztQvHp/8Kzizt6b1LxPCYVn82k4rMDuHqKheb8ienMyh1ZiVKGCiEAhojEWA3LinP48JsG9tf4X8vOV5WiE9UHOFl9AGtnB0rKsQuBpSiBMNNQ/OO/Qulcr+IOjAHY/YnbPi5RvOASj399jC6W5PQsGk8c4d5rF3i2dwsGtSaGsVNm88M1f+bBH13kW3gACSlpTJlWzIGy//Tw28eTMk2Wi1A0+xe7VtSRlPo0rU33dmsrgN1qQZJUfkfveaPTqFkyPYsxacOTr28kIATAEKJRq7hwahaZiTo+Oeh/KWs33sLg9Sfu6TGv3wK8jxIQczvKFOEV4I8oOfD/gTJqcCNT/q/Nnm82SxeNvTIZ4dGwq1Rqdm1/my0vPOqJv1ep1Did3b3w2s0m9nxaCoA6Rus1DVEhSZeC9Ctk5wLXsmb0mf9g8pyd7Nu5mbkll9LWbOLbXZ/idDoC1u57o4+L4bszR502mXsGCyEAhoGiMXoyErW8F0RxSzdtZhP6jFF0tjV7RRk6UDToL6B4EK5GsaU/h5KJaD1KwNEuNFodcYnJdLa1KG9ZlQpkxZdeUqmRnQ40MVpsli7KPzwlKLyH9D07f0+Uzp+NkiH5ZmR5gmtW0oqSqPNPNNeb2LX9lIny9SfuQZadIVXgHZMax8XTs0/rMN5wERaNSKRXBIpE8lLjuMqQR3ZScLbo6+98nDv+XykJel+OLFaUjj4ZuAbFe06PMr8uA8qxW39Hq2ksdqvyhpadTte8HGSng+z88dzyl1eYW/I99BnZ/bYlbVRejyXpnErQcRz4CzABpXbMr4A84He4S0rYbVaPctI7aCqQxB1uZuXpuWxWjuj8fhKO2oCLUdS240NvTnTh1gt8XNnInhN+JBjwQe74qUwxnMe8pVfy0aa/s/uTrdgs3iOCV1yf6Sgpsq9DcaUtAu4BTgLvoSjeKlDMcB20NNYjy0oCzQmzz6Ks9C0v5d0p0kcXkJ4zD4ctg+bGySjJTWZy6t1iQamM9CJKrj7f4dNFFywB/POP8IVKklg4OYNpOcl+7yMIT1LQbQaDoSocjYlGNGoVCydnkqOP5cMDDdid/pcpg+4d5qrfPkDVHiNNtcdRa2Jw2G2o1GpUKjV2217gNyimtsXAJa7PWBTfgh+7juIAKulsO87Tq9qxdl7HiSoLeROvw9Jpo/7YcRTvvLFAIY0nCmg8EdejVRZgO0qI8z8B385Q3rgTnwTjsZig1bBkRjY5wrMvYIZUBxBppcGCZTDamKaCRfk6tleaae4KPo1Wdv5Exs04g7OWLMf4wRtU7vqU5oYaNDFal+beAmx2fUBJSroIpaz5HBQrwmRgMlbXQMLSAce+7eekUgPxSZV0tLwF0g6QPwP6NndKkork9GxSR+Vy+OtyRd8QZNm07CQtC/PjUVtaqKsLbhTli2h5DkVpsCAZjDZmAQV5o/jXNw0cqPXfVOjNDfc85Ym1Hz+9mJfuvZXmxtp+curtcX3c6FAUh9koCTazUOb0DsCKSuXE6WxH8c+vAqpBbqPD3ff6MGykjx6L6eRRj2Vg4pz5aDRqDu8t89j6E5P1jMrL9/u3zs7Tc+74dFSDVP8gGp7DAQWAwWC4ycfiKldJcEGY0WpUXDQtizGpcfz724YBaxUOxPV3Pk6Lqb6PnHq+sNBbKJwiwBmKgiThdDi6BT811deg0Wj6DYjqixi1ikWTM6Mqc89gMaAAOF0Lf0Q6U3OSyE7W8f7XdTS0BV4e2xvv8GNvRZ67UEZadi5N9SeRg+rd/ZM2Ko/Jc8+ltanBo9RbdvOdvTIC+avwy0zU8Z3pWaTGC/t+OAjZDGgwGJYrfwzLw9AegRdpCVqunDOaojEpSIQ2zHWb13715GukZueSlp3Lr554jbMvuQoZenV+fcYoV1ae7ueVVGr0GdnMXfw9ppx5Qfd1kgpJpUIbG88P16wlO38CNqvFL199f5iVp+fKubmi84eRcFgBNqL4ngoGAY1axXkT0ilIj2fb/nq/ahX6wrsD/t4rXdaym+/kpXtvBWDMpBkAHPtmD6PHTyUxJa1XnUPZ6WDC7LO4auUDvHTvrWSPHU/t0SokSUJ2Ojl76QpP7MDs8y8mHMTGqCmZkklhRmTnORyJCE/AEUJeahzXnJHHR5UN7DsZnIKwL/p6O790762cfclV1B2rps3cSGdbCy2NdVTvKfPs99K9t1I484yA5/H+Mj4zgQWTMon3o06BIHCEABhBaDUqFk/JYmJmIv/6piGkMlv+0FedQ1PtcVYtmd7LZh+I485A6DRqzp+YftqW5IoURHD0CCQ/PZ5rzsxj7tiUISle4SuJ6WBmIx6fmcA1Z+aJzj8EiBHACCVGrWL++HQmZSujgZPNgeUZCARfSUxDLWPuC31cDBdMzCA/XYTvDhVCAIxwMhJ1LJ+TS2VdG/85aBq0acFASUxDQaNSMTc/hTlj9FGbnnu4EALgNGFCViIFGQnsPt7Mv/b2nY04WIIN0ukPlSQxLSeJM8alRmVRjkhAXPXTCLVKonhMCpmaLmqsOiqONgedb2AwkZAYnx7HhbPHkBLvX21FweAgBMBpiFatwpCfyuw8PXtPtFJ+1EybJfgAo3ARo1YxdVQSs/KSsbWZReePAIQAOI2JUasoGqNnVm4yh00d7D3ZyqHGjoBTkYWKPi6GGaOTmZ6ThM6VqKOubUibIOgDIQCiAJVKoiAjgYKMBNosdg7UtlHd0E5Ni2XQhEGSTsOErEQmZiWSnawblHMIQkcIgCgjUadh7tgU5o5NodPq4LCpgyOmTupaLZg7bMh9xfMOQGyMmtH6WEanxDJaH0dWknZIfBQEoSEEQBQTp1UzZVSSx+HGanfS2G6loc1Ku9VOp9VBp81Bl82JJClae7VK+STpNCTHadDHxqCPi0EfpxEdfgQiBIDAg1ajIkcfK1JrRRHC60IgiGKEABAIohghAASCKEYIAIEgihECQCCIYoQAEAiimHCUBnOnDR9vNBp/F+rxBALB0BHSCMBVF3CbK3V4oeu7QCAYIYQ6BShEKTQHSpmYwhCPJxAIhpCQpgA9iobMQakG2SeiNuDQEOntg8hvY6S3DyKoNqDBYJgD7DIajbv6207UBhw6Ir19EPltjPT2QeTUBlwsFIACwcgj5NqABoPhJqPR+LDr/8WiaKhAMHIIhxXgTwaD4aDBYGgKU5sEAsEQEaoScBuQGqa2CASCIUZ4AgoEUYwQAAJBFCMEgEAQxQgBIBBEMUIACARRjBAAAkEUIwSAQBDFCAEgEEQxQgAIBFGMEAACQRQjBIBAEMUIASAQRDFCAAgEUYwQAAJBFCMEgEAQxQgBIBBEMUIACARRjBAAAkEUE47SYO7CIBeKzMACwcgiHElBr3TlBpzjqg8gEAhGCOFICupOA144UGEQgUAQWYSrMtBtwM8G2u6uu+4Kx+kEAkGYkGRZDsuBDAbDa8CNRqMx8ouqCQQCIMTSYO45v2voXwXcBDwc3iYKBILBItTSYIsB97w/BdgZjkYJBIKhIaQpgMFgSAF+4Po612g0DqgHEAgEkUPYdACC4cFgMCwHzMAcd5HWPra7rb/1gsjHYDDM6cvS5u9z0JOwWAGCZaBGB/ujhrB9bv3I+OFwgvLSwWwzGAyFfT0gLn+NCxkG/Ywf13AOUAhgNBo3DnHz3G3w9zksHKha9mDhuod/A8b7WOfXc+CLYXMF9m40YO7pRDTQ+gho32Jgm+uBKPTyiBxKVqA8mKAoYYejDX3i5z1c4+r4hcPhSObnc1jlWl81XM5u7vP3sTro52A4YwEGavRwP9wDnb/Qa1mV6/tQkwKYvL6n99zA9TbY1nP5ENHvNXS9WXcCGI3Gh4fJkcyf5+xPrr+R6uw24HPQF8MpAAZqdNA/Kkz0e36j0fis13BwDmAcqoYFSNownnuge3gGkG4wGOa4nMmGg4Hu8y6UN39Tj+1OC0Q0YIi4hoS7hunNYOZUB08BGr1XDvPb318a3dfONSKIKFyWLjPwIPCcwWAYjpHeQPT7HPTHcAqAgRod9I8KE/6ef/EwRkFu4NTUoxBXXIbroQVlXr3cpaxMG4b560DXsJFT81ozyohgqBmojTcBD7qUgzcCESOkvO6zz+fAH4ZTAAz08Ab9o8LEQO3DYDDc5NYaD4cS0OvNuRgwe41CSl3rN3pp1lN8HGKwGegabvRaP1yOZAPeZzeuazksru6u0ZGhxyjJfZ/7eg4GZFj9AFxvpiq8zCsGg6HMaDTO7Wt9pLTPdbFfQ5kXpnEqLFrghZ/32AScMVwjKT/aeJtrfdpwmQEHC+EIJBBEMUIJKBBEMUIACARRjBAAAkEUIwSAQBDFCAEgEEQxQgAIBFGMEAACQRTz/wF6dvcOKkRjEAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "test_x = torch.linspace(0, 1, 51)\n",
    "observed_pred = likelihood(model(test_x))\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Initialize plot\n",
    "    f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
    "\n",
    "    # Get upper and lower confidence bounds\n",
    "    lower, upper = observed_pred.confidence_region()\n",
    "    # Plot training data as black stars\n",
    "    ax.plot(train_x.numpy(), train_y.numpy(), 'k*')\n",
    "    # Plot predictive means as blue line\n",
    "    ax.plot(test_x.numpy(), observed_pred.mean.detach().numpy(), 'b')\n",
    "    # Shade between the lower and upper confidence bounds\n",
    "    ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)\n",
    "    ax.set_ylim([-3, 3])\n",
    "    ax.legend(['Observed Data', 'Mean', 'Confidence'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
